Conversation

Max191 (Contributor) commented Nov 20, 2025

The existing FoldPadWithProducerReshapeOpByExpansion and FoldPadWithProducerReshapeOpByCollapsing patterns did not cover all reshape propagation cases, because they only consider cases where the pad op is the consumer operation. This PR adds two new patterns to cover the cases where the pad op is the producer operation, completing the propagation pattern set for tensor.pad with expand_shape and collapse_shape.
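
As a minimal illustration (shapes and function name are invented for this sketch, not taken from the patch), the new expand_shape pattern takes IR where a tensor.pad feeds an expand_shape and only non-expanded dimensions are padded, and swaps the two ops so the pad is applied to the expanded tensor:

```mlir
// Before: the pad is the producer of the expand_shape; only dim 0, which is
// not being expanded, carries padding.
func.func @pad_then_expand(%src: tensor<4x6xf32>) -> tensor<6x2x3xf32> {
  %cst = arith.constant 0.0 : f32
  %padded = tensor.pad %src low[1, 0] high[1, 0] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<4x6xf32> to tensor<6x6xf32>
  %expanded = tensor.expand_shape %padded [[0], [1, 2]] output_shape [6, 2, 3]
      : tensor<6x6xf32> into tensor<6x2x3xf32>
  return %expanded : tensor<6x2x3xf32>
}

// After the rewrite (conceptually): the unpadded source is expanded first, and
// the pad is re-created on the expanded tensor with zero padding on the new dims.
//   %expanded = tensor.expand_shape %src [[0], [1, 2]] output_shape [4, 2, 3]
//       : tensor<4x6xf32> into tensor<4x2x3xf32>
//   %padded = tensor.pad %expanded low[1, 0, 0] high[1, 0, 0] ...
//       : tensor<4x2x3xf32> to tensor<6x2x3xf32>
```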

Note for integration: This PR also removes the single user restriction for the FoldPadWithProducerReshapeOpByExpansion and FoldPadWithProducerReshapeOpByCollapsing patterns, which leaves more control to the users of the pattern. If this constraint is needed, then it should be added to the control function for these patterns.
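
As a rough sketch of that option (assuming `using namespace mlir;`, an `MLIRContext *context` in scope, and the usual pattern-population boilerplate, none of which is part of this PR; `patterns` and `controlFn` are names chosen for the sketch), the old single-use behavior can be restored through the control function:

```cpp
// Sketch: re-impose a single-use restriction via the control function instead
// of relying on the (now removed) built-in check in the patterns themselves.
RewritePatternSet patterns(context);
linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
  // Only allow the fold when the producer of the fused operand has one use.
  Operation *producer = fusedOperand->get().getDefiningOp();
  return producer && producer->hasOneUse();
};
linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
```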

llvmbot (Member) commented Nov 20, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

The existing FoldPadWithProducerReshapeOpByExpansion and FoldPadWithProducerReshapeOpByCollapsing patterns did not cover all reshape propagation cases, because they only consider cases where the pad op is the consumer operation. This PR adds 2 new patterns to cover the cases where the pad op is the producer operation, which completes the propagation pattern set for pad op with expand_shape and collapse_shape.


Patch is 21.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168888.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+234-49)
  • (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+39)
  • (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+41)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05fc7cbbb90af..8c5a0c1474408 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1038,6 +1038,54 @@ class FoldWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Carries information about a padded dimension.
+struct PadDimInfo {
+  // The resulting shape after padding each dimension.
+  SmallVector<int64_t> paddedShape;
+
+  // Low and high padding amounts for each dimension.
+  SmallVector<OpFoldResult> lowPad;
+  SmallVector<OpFoldResult> highPad;
+};
+
+/// Computes the expanded padding information for the given pad operation based
+/// on the provided expanded shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the padded
+/// size for each dimension, or failure if the expansion is not possible.
+static FailureOr<PadDimInfo>
+computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
+                       ArrayRef<ReassociationIndices> reassociations,
+                       PatternRewriter &rewriter) {
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+  // Expanded dimensions cannot have padding because the resulting padding may
+  // not be representable by a tensor.pad op. There are some special cases where
+  // it is possible (like expanding unit dims), but supporting these cases is
+  // NYI, so disallow it for now.
+  for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+    if (reInd.size() != 1 && (l != 0 || h != 0))
+      return failure();
+  }
+
+  SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+  SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+  ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.paddedShape.assign(expandedShape);
+  padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
+      padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
+      padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
+    }
+  }
+
+  return padDimInfo;
+}
+
 class FoldPadWithProducerReshapeOpByExpansion
     : public OpRewritePattern<tensor::PadOp> {
 public:
@@ -1061,38 +1109,92 @@ class FoldPadWithProducerReshapeOpByExpansion
                                          "fusion blocked by control function");
     }
 
-    ArrayRef<int64_t> low = padOp.getStaticLow();
-    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    RankedTensorType expandedType = reshapeOp.getSrcType();
     SmallVector<ReassociationIndices> reassociations =
         reshapeOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+        padOp, expandedType.getShape(), reassociations, rewriter);
+    if (failed(maybeExpandedPadding))
+      return failure();
+    PadDimInfo expandedPadding = maybeExpandedPadding.value();
 
-    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
-      if (reInd.size() != 1 && (l != 0 || h != 0))
-        return failure();
+    Location loc = padOp->getLoc();
+    RankedTensorType expandedPaddedType =
+        padOp.getResultType().clone(expandedPadding.paddedShape);
+
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
+        expandedPadding.lowPad, expandedPadding.highPad,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
+class FoldExpandShapeWithProducerPadOp
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+public:
+  FoldExpandShapeWithProducerPadOp(MLIRContext *context,
+                                   ControlFusionFn foldReshapes,
+                                   PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
     }
 
-    SmallVector<OpFoldResult> newLow, newHigh;
-    RankedTensorType expandedType = reshapeOp.getSrcType();
-    RankedTensorType paddedType = padOp.getResultType();
-    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+    RankedTensorType expandedType = expandOp.getResultType();
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+        padOp, expandedType.getShape(), reassociations, rewriter);
+    if (failed(maybeExpandedPadding))
+      return failure();
+    PadDimInfo expandedPadding = maybeExpandedPadding.value();
+
+    Location loc = expandOp->getLoc();
+    SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
+    SmallVector<int64_t> newExpandedShape(expandedType.getShape());
+    rewriter.setInsertionPointAfterValue(padOp.getSource());
+    SmallVector<OpFoldResult> padSrcSizes =
+        tensor::getMixedSizes(rewriter, loc, padOp.getSource());
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      // We know that any reassociation with multiple dims is not padded because
+      // of the requirements of computeExpandedPadding.
       if (reInd.size() == 1) {
-        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
-      }
-      for (size_t i = 0; i < reInd.size(); ++i) {
-        newLow.push_back(padOp.getMixedLowPad()[idx]);
-        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+        newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
+        newExpandedSizes[reInd[0]] = padSrcSizes[idx];
       }
     }
-
-    Location loc = padOp->getLoc();
-    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+    RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
+    auto newExpandOp = tensor::ExpandShapeOp::create(
+        rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
+        newExpandedSizes);
+    RankedTensorType expandedPaddedType =
+        padOp.getResultType().clone(expandedPadding.paddedShape);
+    rewriter.setInsertionPoint(expandOp);
     auto newPadOp = tensor::PadOp::create(
-        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
+        expandedPadding.lowPad, expandedPadding.highPad,
         padOp.getConstantPaddingValue(), padOp.getNofold());
 
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+    rewriter.replaceOp(expandOp, newPadOp.getResult());
 
     return success();
   }
@@ -1921,6 +2023,52 @@ struct FoldReshapeWithGenericOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Computes the collapsed padding information for the given pad operation based
+/// on the provided collapsed shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the collapsed
+/// shape for each dimension, or failure if the collapse is not possible.
+static FailureOr<PadDimInfo>
+computeCollapsedPadding(tensor::PadOp padOp,
+                        ArrayRef<ReassociationIndices> reassociations,
+                        PatternRewriter &rewriter) {
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+  // Collapsed dimensions cannot have padding because this can produce strided
+  // padding that isn't representable by a tensor.pad op. There are some special
+  // cases where it is possible (like collapsing unit dims), but supporting
+  // these cases is NYI, so disallow it for now.
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    for (int64_t dim : reInd) {
+      if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
+        return failure();
+    }
+  }
+
+  // Initialize padding values for collapsed tensors with zeros
+  ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+
+  // Update padding for dimensions that are not being collapsed, and compute
+  // the collapsed padded shape.
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.lowPad[idx] = padOp.getMixedLowPad()[reInd[0]];
+      padDimInfo.highPad[idx] = padOp.getMixedHighPad()[reInd[0]];
+    }
+    SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
+    for (int64_t dim : reInd) {
+      collapsedSize =
+          collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
+    }
+    padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
+  }
+
+  return padDimInfo;
+}
+
 class FoldPadWithProducerReshapeOpByCollapsing
     : public OpRewritePattern<tensor::PadOp> {
 public:
@@ -1944,49 +2092,34 @@ class FoldPadWithProducerReshapeOpByCollapsing
                                          "fusion blocked by control function");
     }
 
-    ArrayRef<int64_t> low = padOp.getStaticLow();
-    ArrayRef<int64_t> high = padOp.getStaticHigh();
     SmallVector<ReassociationIndices> reassociations =
         reshapeOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeCollapsedPadding =
+        computeCollapsedPadding(padOp, reassociations, rewriter);
+    if (failed(maybeCollapsedPadding))
+      return failure();
+    PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
 
-    for (auto reInd : reassociations) {
-      if (reInd.size() == 1)
-        continue;
-      if (llvm::any_of(reInd, [&](int64_t ind) {
-            return low[ind] != 0 || high[ind] != 0;
-          })) {
-        return failure();
-      }
-    }
-
-    SmallVector<OpFoldResult> newLow, newHigh;
-    RankedTensorType collapsedType = reshapeOp.getSrcType();
-    RankedTensorType paddedType = padOp.getResultType();
-    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
-    SmallVector<OpFoldResult> expandedPaddedSizes(
-        getMixedValues(reshapeOp.getStaticOutputShape(),
-                       reshapeOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> expandedPaddedSizes =
+        reshapeOp.getMixedOutputShape();
     AffineExpr d0, d1, d2;
     bindDims(rewriter.getContext(), d0, d1, d2);
     auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
     Location loc = reshapeOp->getLoc();
-    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
-      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+    for (auto [reInd, l, h] :
+         llvm::zip_equal(reassociations, collapsedPadding.lowPad,
+                         collapsedPadding.highPad)) {
       if (reInd.size() == 1) {
-        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
-        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+        expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
             rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
-        expandedPaddedSizes[reInd[0]] = paddedSize;
       }
-      newLow.push_back(l);
-      newHigh.push_back(h);
     }
 
     RankedTensorType collapsedPaddedType =
-        paddedType.clone(collapsedPaddedShape);
+        padOp.getType().clone(collapsedPadding.paddedShape);
     auto newPadOp = tensor::PadOp::create(
-        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
+        collapsedPadding.lowPad, collapsedPadding.highPad,
         padOp.getConstantPaddingValue(), padOp.getNofold());
 
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -2000,6 +2133,54 @@ class FoldPadWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldReshapeWithProducerPadOpByCollapsing
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+public:
+  FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+    RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
+    FailureOr<PadDimInfo> maybeCollapsedPadding =
+        computeCollapsedPadding(padOp, reassociations, rewriter);
+    if (failed(maybeCollapsedPadding))
+      return failure();
+    PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
+
+    Location loc = reshapeOp->getLoc();
+    auto newCollapseOp = tensor::CollapseShapeOp::create(
+        rewriter, loc, padOp.getSource(), reassociations);
+
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
+        collapsedPadding.lowPad, collapsedPadding.highPad,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(reshapeOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to collapse dimensions.
 template <typename LinalgType>
 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -2239,6 +2420,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                     controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
+  patterns.add<FoldExpandShapeWithProducerPadOp>(patterns.getContext(),
+                                                 controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -2250,6 +2433,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
                                                       controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
   patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 2bf3d21c35526..923bb2ca9c28a 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -639,6 +639,45 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
 // CHECK-SAME:       output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
 //      CHECK:   return %[[EXPAND]]
 
+// -----
+
+func.func @collapse_shape_with_producer_pad(%arg0: tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+  %cst = arith.constant 0 : i32
+  %padded = tensor.pad %arg0 low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+  %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]]
+    : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+  return %collapsed : tensor<8x12x17x336x14xi32>
+}
+//      CHECK: func @collapse_shape_with_producer_pad
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+//      CHECK:   return %[[PAD]]
+
+// -----
+
+func.func @collapse_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?x?x?xf32>,
+    %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+  %padded = tensor.pad %arg0 low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+  %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5]]
+    : tensor<?x?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %collapsed : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @collapse_shape_with_producer_pad_dynamic
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?x?x?xf32>
+// CHECK-SAME:   %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]]
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+//      CHECK:   return %[[PAD]]
+
 // -----
 // Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes.
 #map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b32bad5..f6572674d10e2 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -863,6 +863,47 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
 
 // -----
 
+func.func @expand_shape_with_producer_pad(%arg0: tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+  %cst = arith.constant 0 : i32
+  %padded = tensor.pad %arg0 low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+  %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [8, 3, 4, 17, 6, 7, 8, 14]
+    : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+  return %expanded : tensor<8x3x4x17x6x7x8x14xi32>
+}
+//      CHECK: func @expand_shape_with_producer_pad
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\...
[truncated]

github-actions bot commented Nov 20, 2025

🐧 Linux x64 Test Results

  • 7159 tests passed
  • 595 tests skipped

Max191 requested a review from IanWood1 on November 20, 2025 at 22:24

IanWood1 (Contributor) left a comment:

LGTM!

Max191 force-pushed the add-missing-pad-reshape-propagation branch from 770eaa3 to d1b6dac on December 1, 2025 at 17:12
Signed-off-by: Max Dawkins <[email protected]>
Max191 force-pushed the add-missing-pad-reshape-propagation branch from d1b6dac to 3594ff0 on December 1, 2025 at 20:11
Max191 merged commit af2e246 into llvm:main on December 1, 2025
10 checks passed
Max191 deleted the add-missing-pad-reshape-propagation branch on December 1, 2025 at 20:46
kcloudy0717 pushed a commit to kcloudy0717/llvm-project that referenced this pull request Dec 4, 2025
The existing `FoldPadWithProducerReshapeOpByExpansion` and
`FoldPadWithProducerReshapeOpByCollapsing` patterns did not cover all
reshape propagation cases, because they only consider cases where the
pad op is the consumer operation. This PR adds 2 new patterns to cover
the cases where the pad op is the producer operation, which completes
the propagation pattern set for pad op with expand_shape and
collapse_shape.

Note for integration: This PR also removes the single user restriction
for the `FoldPadWithProducerReshapeOpByExpansion` and
`FoldPadWithProducerReshapeOpByCollapsing` patterns, which leaves more
control to the users of the pattern. If this constraint is needed, then
it should be added to the control function for these patterns.

---------

Signed-off-by: Max Dawkins <[email protected]>